# =============================================================================
# Statistical Analysis Code for:
# "Chronic adaptive versus conventional DBS response patterns in Parkinson's
#  disease: A pilot randomized crossover trial"
#
# Visualization: Heatmap with LOO Sensitivity Analysis Markers
# Author: Jun Tanimura, MD, MSc.
# Created: 2025-11-01
# Last Updated: 2025-11-25
#
# This script creates an enhanced heatmap in which non-robust combinations
# (those whose effect changes sign in the leave-one-out (LOO) analysis) are
# drawn as white, unfilled cells.
# Based on: 03_ANCOVA_EXP_Plot_v3.3.R and LOO sensitivity results
# =============================================================================

library(tidyverse)
library(ggplot2)

# =============================================================================
# Load data
# =============================================================================

# Load the main modifier results: use the first candidate file name that
# exists in the working directory.
possible_files <- c(
  "treatment_effect_modifiers.csv",
  "results_treatment_effect_modifiers.csv",
  "results_treatment_effect_modifiers_unified_fixed.csv"
)
# Find() returns NULL when no candidate exists.
modifier_file <- Find(file.exists, possible_files)

if (is.null(modifier_file)) {
  stop("Modifier results file not found. Please run 03_ANCOVA_EXP_v3.3_20250907.R first or check file path.",
       call. = FALSE)
}

cat(sprintf("Loading modifier data from: %s\n", modifier_file))
modifier_data <- read_csv(modifier_file, show_col_types = FALSE)

# Some result versions name the outcome column "Outcome"; normalise to
# "Evaluation", which the rest of the script uses.
if ("Outcome" %in% names(modifier_data) && !"Evaluation" %in% names(modifier_data)) {
  modifier_data <- modifier_data %>% rename(Evaluation = Outcome)
}

# The LOO merge and label-cleaning steps key on these columns; fail early
# with a clear message instead of a cryptic join error later on.
missing_modifier_cols <- setdiff(c("Evaluation", "Modifier"), names(modifier_data))
if (length(missing_modifier_cols) > 0) {
  stop(sprintf("Modifier results file '%s' is missing required column(s): %s",
               modifier_file, paste(missing_modifier_cols, collapse = ", ")),
       call. = FALSE)
}

# Exclude sex as a modifier (any Modifier name starting with "sex").
modifier_data <- modifier_data %>% filter(!str_detect(Modifier, "^sex"))

# Load LOO sensitivity results.
loo_file <- "results_loo_sensitivity_analysis_summary.csv"
if (!file.exists(loo_file)) {
  stop(sprintf("LOO summary file not found: %s. Please run the LOO sensitivity analysis first or check file path.",
               loo_file),
       call. = FALSE)
}
loo_summary <- read_csv(loo_file, show_col_types = FALSE)

# Columns consumed by the merge step below; fail early if the layout differs.
loo_required_cols <- c("Outcome", "Baseline", "Sign_change", "Robust_status", "LOO_range")
loo_missing_cols <- setdiff(loo_required_cols, names(loo_summary))
if (length(loo_missing_cols) > 0) {
  stop(sprintf("LOO summary file '%s' is missing required column(s): %s",
               loo_file, paste(loo_missing_cols, collapse = ", ")),
       call. = FALSE)
}

# =============================================================================
# Merge LOO results with modifier data
# =============================================================================

# Create mapping for time variable names between LOO results and main analysis
# LOO results use: ON_time, OFF_time, Dys_sev_time, Dys_weak_time
# Main analysis uses: Mean_corr_ON_time, Mean_corr_OFF_time, Mean_corr_Dys_sev_time, Mean_corr_Dys_weak_time
time_outcome_mapping <- tibble(
  LOO_Outcome = c("ON_time", "OFF_time", "Dys_sev_time", "Dys_weak_time"),
  Main_Evaluation = c("Mean_corr_ON_time", "Mean_corr_OFF_time",
                      "Mean_corr_Dys_sev_time", "Mean_corr_Dys_weak_time")
)

# Prepare LOO summary with mapped evaluation names. Outcomes without a mapping
# entry are assumed to already use the main-analysis name.
loo_summary_mapped <- loo_summary %>%
  left_join(time_outcome_mapping, by = c("Outcome" = "LOO_Outcome")) %>%
  mutate(Evaluation_for_merge = coalesce(Main_Evaluation, Outcome))

# Each outcome/baseline pair must appear at most once in the LOO summary;
# otherwise the left join below would silently duplicate heatmap cells and
# inflate the summary counts.
duplicated_loo_keys <- loo_summary_mapped %>%
  count(Evaluation_for_merge, Baseline) %>%
  filter(n > 1)
if (nrow(duplicated_loo_keys) > 0) {
  stop(sprintf("LOO summary contains %d duplicated Outcome/Baseline combination(s); cannot merge unambiguously.",
               nrow(duplicated_loo_keys)),
       call. = FALSE)
}

# Merge LOO information onto every modifier/outcome combination.
modifier_data_with_loo <- modifier_data %>%
  left_join(
    loo_summary_mapped %>%
      select(Evaluation_for_merge, Baseline, Sign_change, Robust_status, LOO_range),
    by = c("Evaluation" = "Evaluation_for_merge", "Modifier" = "Baseline")
  ) %>%
  mutate(
    # Non-robust = Sign_change flagged in the LOO summary. as.logical() keeps
    # the type logical even if the CSV stored it as text or 0/1.
    # NOTE(review): combinations absent from the LOO summary (NA Sign_change)
    # become FALSE and are therefore drawn as robust (coloured) cells.
    Non_robust = coalesce(as.logical(Sign_change), FALSE)
  )

# Attach display labels and effect-size markers to every modifier/outcome
# combination (label wording matches the original heatmap script). Helpers
# live inside local() so only `modifier_data_clean` is created globally.
modifier_data_clean <- local({
  src <- modifier_data_with_loo
  src_cols <- names(src)

  # Partial r used for the heatmap fill, in order of preference:
  #   1. an existing Partial_r_Unified column;
  #   2. conversion from Cohen's d via r = d / sqrt(d^2 + 4) (the equal-group
  #      d-to-r conversion, used here as an approximation of partial r);
  #   3. derivation from the interaction t-statistic via r = t / sqrt(t^2 + df);
  #   4. otherwise missing.
  partial_r_values <-
    if ("Partial_r_Unified" %in% src_cols) {
      src$Partial_r_Unified
    } else if ("Effect_Size_Unified" %in% src_cols) {
      d_values <- src$Effect_Size_Unified
      d_values / sqrt(d_values^2 + 4)
    } else if ("Interaction_T" %in% src_cols && "Interaction_DF" %in% src_cols) {
      t_values <- src$Interaction_T
      t_values / sqrt(t_values^2 + src$Interaction_DF)
    } else {
      NA_real_
    }

  # Predictor labels keyed by name prefix; the first matching prefix in this
  # list wins. Unmatched names fall back to "Worse <name>" with any
  # "_standardized" suffix removed.
  modifier_prefix_labels <- c(
    age                       = "Older Age",
    PD_duration               = "Longer PD Duration",
    LEDD                      = "Higher LEDD",
    P_ADL                     = "Worse ADL",
    P_HY                      = "Higher Hoehn & Yahr Stage",
    P_UPDRS_part_1            = "Worse UPDRS Part 1",
    P_UPDRS_part_2            = "Worse UPDRS Part 2",
    P_UPDRS_part_3            = "Worse UPDRS Part 3",
    P_UPDRS_part_4            = "Worse UPDRS Part 4",
    P_UPDRS_total             = "Worse UPDRS Total",
    P_Mean_corr_ON_time       = "Shorter ON Time",
    P_Mean_corr_OFF_time      = "Longer OFF Time",
    P_Mean_corr_Dys_weak_time = "Longer Mild Dyskinesia Time",
    P_Mean_corr_Dys_sev_time  = "Longer Severe Dyskinesia Time",
    P_MMSE                    = "Lower MMSE",
    P_PDQ_39                  = "Higher PDQ39"
  )

  label_modifier <- function(raw) {
    raw <- as.character(raw)
    labelled <- rep(NA_character_, length(raw))
    for (prefix in names(modifier_prefix_labels)) {
      take <- is.na(labelled) & !is.na(raw) & startsWith(raw, prefix)
      labelled[take] <- modifier_prefix_labels[[prefix]]
    }
    unmatched <- is.na(labelled)
    labelled[unmatched] <- paste0("Worse ", str_replace(raw[unmatched], "_standardized", ""))
    labelled
  }

  # Outcome labels keyed by exact name; anything else becomes "Better <name>".
  outcome_labels <- c(
    Mean_corr_ON_time       = "Longer ON Time",
    Mean_corr_OFF_time      = "Shorter OFF Time",
    Mean_corr_Dys_sev_time  = "Shorter Troublesome Dyskinesia",
    Mean_corr_Dys_weak_time = "Shorter Non-troublesome Dyskinesia",
    ADL                     = "Better ADL",
    UPDRS_part_1            = "Better UPDRS Part 1",
    UPDRS_part_2            = "Better UPDRS Part 2",
    UPDRS_part_3            = "Better UPDRS Part 3",
    UPDRS_part_4            = "Better UPDRS Part 4",
    UPDRS_total             = "Better UPDRS Total",
    MMSE                    = "Better MMSE",
    PDQ_39                  = "Better PDQ_39"
  )

  label_outcome <- function(raw) {
    raw <- as.character(raw)
    coalesce(unname(outcome_labels[raw]), paste0("Better ", raw))
  }

  # Number of stars per cell: one for |r| > 0.5, two for |r| > 0.8, none
  # when r is missing.
  star_count <- function(r) {
    magnitude <- abs(r)
    coalesce((magnitude > 0.5) + (magnitude > 0.8), 0L)
  }

  src %>%
    mutate(
      Partial_r_Unified = partial_r_values,
      Modifier_Clean    = label_modifier(Modifier),
      Evaluation_Clean  = label_outcome(Evaluation),
      Effect_Symbol     = strrep("*", star_count(Partial_r_Unified))
    )
})

# =============================================================================
# Create enhanced heatmap with LOO markers
# =============================================================================

# Build the treatment-effect-modifier heatmap with LOO robustness marking.
#
# `data` must provide the columns Modifier_Clean, Evaluation_Clean,
# Partial_r_Unified, Non_robust and Effect_Symbol. Cells whose Non_robust is
# TRUE get no fill (rendered white via na.value) and no symbol. All other
# cells are coloured by partial r on a fixed [-1, 1] diverging scale and
# annotated with Effect_Symbol. Returns a ggplot object.
create_loo_enhanced_heatmap <- function(data) {

  # Preferred predictor order along the x axis (left to right).
  modifier_order <- c(
    "Older Age", "Longer PD Duration", "Higher LEDD", "Worse ADL", "Higher Hoehn & Yahr Stage",
    "Worse UPDRS Part 1", "Worse UPDRS Part 2", "Worse UPDRS Part 3",
    "Worse UPDRS Part 4", "Worse UPDRS Total",
    "Shorter ON Time", "Longer OFF Time",
    "Longer Mild Dyskinesia Time", "Longer Severe Dyskinesia Time",
    "Lower MMSE", "Higher PDQ39"
  )

  # Preferred outcome order along the y axis (first level is drawn at the bottom).
  evaluation_order <- c("Better PDQ_39", "Better MMSE",
                        "Better UPDRS Total", "Better UPDRS Part 4", "Better UPDRS Part 3",
                        "Better UPDRS Part 2", "Better UPDRS Part 1", "Better ADL",
                        "Shorter Non-troublesome Dyskinesia",
                        "Shorter Troublesome Dyskinesia", "Shorter OFF Time", "Longer ON Time"
  )

  # Category separators are drawn just to the right of these predictors.
  # Anchoring them to labels keeps them aligned when ggplot drops predictors
  # that have no data. With all 16 predictors shown this gives x = 4.5, 9.5,
  # 13.5 and 14.5.
  # NOTE(review): confirm the intended grouping (e.g. whether mild and severe
  # dyskinesia should be separated).
  separator_after <- c(
    "Worse ADL", "Worse UPDRS Part 4",
    "Longer Mild Dyskinesia Time", "Longer Severe Dyskinesia Time"
  )

  # Labels absent from the preferred order are appended after it (in order of
  # first appearance) instead of becoming NA. NA would stack several distinct
  # predictors/outcomes onto one overlapping "NA" column or row.
  with_levels <- function(x, preferred) {
    factor(x, levels = union(preferred, unique(x[!is.na(x)])))
  }

  plot_data <- data %>%
    mutate(
      Modifier_Clean = with_levels(Modifier_Clean, modifier_order),
      Evaluation_Clean = with_levels(Evaluation_Clean, evaluation_order),
      # Non-robust cells get no fill value, so they render with na.value
      # (white). Robust cells keep their partial r.
      Fill_value = ifelse(Non_robust, NA_real_, Partial_r_Unified)
    )

  # x positions of the separators among the predictors actually shown.
  shown_modifiers <- levels(droplevels(plot_data$Modifier_Clean))
  separator_x <- match(separator_after, shown_modifiers) + 0.5
  separator_x <- separator_x[!is.na(separator_x)]

  p_heat <- ggplot(plot_data, aes(x = Modifier_Clean, y = Evaluation_Clean)) +

    # Heatmap tiles; non-robust cells are white (NA fill).
    geom_tile(aes(fill = Fill_value), color = "white", linewidth = 0.3) +

    # Effect-size symbols, only for robust combinations.
    geom_text(
      data = plot_data %>% filter(Non_robust %in% FALSE),
      aes(label = Effect_Symbol),
      size = 2, color = "black", fontface = "bold"
    ) +

    # Diverging colour scale for robust combinations.
    scale_fill_gradient2(
      low = "#2166AC",
      mid = "white",
      high = "#B2182B",
      midpoint = 0,
      name = "Effect Size\n(partial r)",
      limits = c(-1, 1),
      na.value = "white"  # Non-robust combinations will be white
    ) +

    labs(
      title = "Treatment Effect Modifiers with LOO Sensitivity Analysis",
      subtitle = "Effect sizes (partial r): White cells = Non-robust (sign change in LOO), Colored cells = Robust",
      x = "Baseline Predictors",
      y = "Clinical Outcomes",
      caption = "* |r| > 0.5, ** |r| > 0.8 | White = Non-robust (LOO sign change), Colored = Robust"
    ) +

    theme_minimal(base_size = 10) +
    theme(
      plot.title = element_text(size = 14, face = "bold"),
      plot.subtitle = element_text(size = 10, color = "gray30"),
      axis.text.x = element_text(angle = 45, hjust = 1, size = 9),
      axis.text.y = element_text(size = 8),
      axis.title = element_text(size = 11, face = "bold"),
      legend.position = "right",
      plot.caption = element_text(size = 8, color = "gray50", hjust = 0),
      panel.grid = element_blank(),
      plot.margin = margin(10, 10, 10, 10)
    )

  if (length(separator_x) > 0) {
    p_heat <- p_heat +
      geom_vline(xintercept = separator_x, color = "gray30", linewidth = 0.8, alpha = 0.7)
  }

  p_heat
}

# =============================================================================
# Generate and save plots
# =============================================================================

cat("\n=== Creating Enhanced Heatmaps with LOO Markers ===\n")

# Main heatmap with LOO markers
heatmap_file <- "figure_treatment_modifier_heatmap_with_loo_markers.pdf"
heatmap_with_loo <- create_loo_enhanced_heatmap(modifier_data_clean)
ggsave(heatmap_file, heatmap_with_loo, width = 6.5, height = 4.5, bg = "white")
cat(sprintf("Saved: %s\n", heatmap_file))

# Export merged data for reference. The file name is held in a variable so
# the confirmation message always reports the file actually written.
status_file <- "results_treatment_modifiers_with_loo_status.csv"
write_csv(modifier_data_clean, status_file)
cat(sprintf("Saved: %s\n", status_file))

# Summary statistics.
# Non_robust is never NA (missing LOO results are mapped to FALSE), so LOO
# coverage is measured from Sign_change instead. Robust/non-robust counts and
# percentages are restricted to combinations that were actually assessed.
loo_assessed <- !is.na(modifier_data_clean$Sign_change)
n_total <- nrow(modifier_data_clean)
n_assessed <- sum(loo_assessed)
n_non_robust <- sum(loo_assessed & modifier_data_clean$Non_robust %in% TRUE)
n_robust <- sum(loo_assessed & modifier_data_clean$Non_robust %in% FALSE)
pct_non_robust <- if (n_assessed > 0) 100 * n_non_robust / n_assessed else NA_real_
pct_robust <- if (n_assessed > 0) 100 * n_robust / n_assessed else NA_real_

cat("\n=== Summary Statistics ===\n")
cat(sprintf("Total combinations: %d\n", n_total))
cat(sprintf("Combinations with LOO analysis: %d\n", n_assessed))
cat(sprintf("Non-robust combinations: %d (%.1f%% of LOO-analyzed)\n",
            n_non_robust, pct_non_robust))
cat(sprintf("Robust combinations: %d (%.1f%% of LOO-analyzed)\n",
            n_robust, pct_robust))
cat(sprintf("Combinations without LOO result: %d\n", n_total - n_assessed))

cat("\n=== LOO-Enhanced Heatmap Visualization Complete ===\n")

